package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.HasOutputCols
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLWritable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}

/**
 * A [[Transformer]] that keeps only the columns named in `outputCols` and
 * drops every other column from the input dataset.
 *
 * The selected columns appear in the output in the order given by
 * `outputCols`, not in the input schema's order.
 *
 * @param uid unique identifier for this stage
 */
class ColumnFilter(override val uid: String) extends Transformer
  with HasOutputCols with DefaultParamsWritable {
  // NOTE: DefaultParamsWritable already extends MLWritable, so the previous
  // explicit `MLWritable` mixin was redundant and has been dropped.

  def this() = this(Identifiable.randomUID("ColumnFilter"))

  /** Sets the names of the columns to keep. */
  def setFilterColumns(value: Array[String]): this.type = set(outputCols, value)

  /**
   * Projects the dataset down to the configured columns.
   *
   * @param dataset input dataset
   * @return a DataFrame containing exactly the `outputCols` columns,
   *         in `outputCols` order
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Validate up front: fails fast if any requested column is missing,
    // instead of erroring deep inside the select.
    transformSchema(dataset.schema)
    dataset.select($(outputCols).map(dataset.col): _*)
  }

  /**
   * Computes the output schema: the requested fields, unchanged, in
   * `outputCols` order — matching exactly what [[transform]] produces.
   *
   * Previous version (a) injected leftover debug metadata ("aaaa" -> "xxx")
   * that transform() never actually emitted, and (b) returned fields in
   * input-schema order while transform() selects in outputCols order, and
   * (c) silently dropped missing columns. All three are fixed here.
   *
   * @throws IllegalArgumentException if a requested column is absent
   *         (thrown by `schema(name)` at planning time)
   */
  override def transformSchema(schema: StructType): StructType = {
    StructType($(outputCols).map(name => schema(name)))
  }

  override def copy(extra: ParamMap): ColumnFilter = defaultCopy(extra)

}

/** Companion object enabling `ColumnFilter.load(path)` persistence. */
object ColumnFilter extends DefaultParamsReadable[ColumnFilter] {

  /**
   * Loads a previously saved [[ColumnFilter]] from `path`.
   *
   * The override is intentional even though it only delegates: it pins the
   * concrete return type on the companion, which keeps the method easily
   * discoverable and callable from Java.
   */
  override def load(path: String): ColumnFilter = super.load(path)

}